/*
 * Copyright 2016 LinkedIn, Inc
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not
 * use this file except in compliance with the License. You may obtain a copy of
 * the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package com.linkedin.restli.client;

import java.net.HttpCookie;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.linkedin.common.callback.Callback;
import com.linkedin.data.DataMap;
import com.linkedin.data.schema.PathSpec;
import com.linkedin.data.template.RecordTemplate;
import com.linkedin.parseq.batching.Batch;
import com.linkedin.parseq.batching.BatchImpl.BatchEntry;
import com.linkedin.parseq.function.Tuple3;
import com.linkedin.parseq.function.Tuples;
import com.linkedin.r2.RemoteInvocationException;
import com.linkedin.r2.message.RequestContext;
import com.linkedin.r2.message.rest.RestResponseBuilder;
import com.linkedin.restli.client.response.BatchKVResponse;
import com.linkedin.restli.common.BatchResponse;
import com.linkedin.restli.common.EntityResponse;
import com.linkedin.restli.common.ErrorResponse;
import com.linkedin.restli.common.HttpStatus;
import com.linkedin.restli.common.ProtocolVersion;
import com.linkedin.restli.common.ResourceMethod;
import com.linkedin.restli.common.ResourceSpec;
import com.linkedin.restli.common.RestConstants;
import com.linkedin.restli.internal.client.ResponseImpl;
import com.linkedin.restli.internal.client.response.BatchEntityResponse;
import com.linkedin.restli.internal.common.ProtocolVersionUtil;
import com.linkedin.restli.internal.common.ResponseUtils;

/**
 * A {@link RequestGroup} for GET-style requests ({@link GetRequest} and {@link BatchRequest} subtypes) that can be
 * merged into a single GET or BATCH_GET call.
 *
 * <p>Two requests fall into the same group when they share the base URI template, headers, cookies, request options,
 * query parameters (ignoring the batch ids and fields parameters), path keys and resource key class; see
 * {@link #equals(Object)}.
 */
class GetRequestGroup implements RequestGroup {

  private static final Logger LOGGER = LoggerFactory.getLogger(GetRequestGroup.class);

  private final String _baseUriTemplate; //taken from first request, used to differentiate between groups
  private final ResourceSpec _resourceSpec;  //taken from first request; its key class is used to differentiate between groups
  private final Map<String, String> _headers; //taken from first request, used to differentiate between groups
  private final List<HttpCookie> _cookies; //taken from first request, used to differentiate between groups
  private final RestliRequestOptions _requestOptions; //taken from first request, used to differentiate between groups
  private final Map<String, Object> _queryParams; //taken from first request, used to differentiate between groups
  private final Map<String, Object> _pathKeys; //taken from first request, used to differentiate between groups
  private final int _maxBatchSize;

  /**
   * Creates a group whose identity is taken from the batching-relevant properties of {@code request}.
   *
   * @param request the first request of the group; its properties define which other requests may join it
   * @param maxBatchSize the value reported by {@link #getMaxBatchSize()}
   */
  @SuppressWarnings("deprecation")
  public GetRequestGroup(Request<?> request, int maxBatchSize) {
    _baseUriTemplate = request.getBaseUriTemplate();
    _headers = request.getHeaders();
    _cookies = request.getCookies();
    _queryParams = getQueryParamsForBatchingKey(request);
    _resourceSpec = request.getResourceSpec();
    _requestOptions = request.getRequestOptions();
    _pathKeys = request.getPathKeys();
    _maxBatchSize = maxBatchSize;
  }

  /**
   * Returns a copy of the request's query parameters without the batch ids and fields parameters. Those two are
   * expected to differ between requests that are merged into one batch, so they must not affect grouping.
   */
  private static Map<String, Object> getQueryParamsForBatchingKey(Request<?> request)
  {
    final Map<String, Object> params = new HashMap<>(request.getQueryParamsObjects());
    params.remove(RestConstants.QUERY_BATCH_IDS_PARAM);
    params.remove(RestConstants.FIELDS_PARAM);
    return params;
  }

  /**
   * Creates a new 404 exception. A fresh instance is built on every call because a {@link Throwable} is mutable
   * (suppressed exceptions, stack trace) and must not be shared between unrelated callers.
   */
  private static RestLiResponseException notFoundException() {
    return new RestLiResponseException(new RestResponseBuilder().setStatus(HttpStatus.S_404_NOT_FOUND.getCode()).build(),
        null, new ErrorResponse().setStatus(HttpStatus.S_404_NOT_FOUND.getCode()));
  }

  /**
   * Extracts the single-entity response for {@code id} from a batch response.
   *
   * @param request the batch request that was sent; only used for logging
   * @param batchResponse the response to the batch request
   * @param id the key to extract, already converted to the resource's key type
   * @return a response wrapping the entity found for {@code id}
   * @throws RestLiResponseException built from the batch's error entry for {@code id} if there is one, or a 404 when
   *         the batch holds neither a non-null entity nor an error for {@code id}
   */
  private static <K, RT extends RecordTemplate> Response<RT> unbatchResponse(BatchGetEntityRequest<K, RT> request,
      Response<BatchKVResponse<K, EntityResponse<RT>>> batchResponse, Object id) throws RemoteInvocationException {
    final BatchKVResponse<K, EntityResponse<RT>> batchEntity = batchResponse.getEntity();
    final ErrorResponse errorResponse = batchEntity.getErrors().get(id);
    if (errorResponse != null) {
      throw new RestLiResponseException(errorResponse);
    }

    final EntityResponse<RT> entityResponse = batchEntity.getResults().get(id);
    if (entityResponse != null) {
      final RT entityResult = entityResponse.getEntity();
      if (entityResult != null) {
        return new ResponseImpl<>(batchResponse, entityResult);
      }
    }

    LOGGER.debug("No result or error for base URI : {}, id: {}. Verify that the batchGet endpoint returns response keys that match batchGet request IDs.",
        request.getBaseUriTemplate(), id);

    throw notFoundException();
  }

  /**
   * Returns a copy of a batch response data map in which the errors, results and statuses sub-maps contain only the
   * given string-encoded ids. All other top-level entries are copied unchanged.
   */
  private static DataMap filterIdsInBatchResult(DataMap data, Set<String> ids) {
    DataMap dm = new DataMap(data.size());
    data.forEach((key, value) -> {
      switch (key) {
        case BatchResponse.ERRORS:
        case BatchResponse.RESULTS:
        case BatchResponse.STATUSES:
          dm.put(key, filterIds((DataMap) value, ids));
          break;
        default:
          dm.put(key, value);
          break;
      }
    });
    return dm;
  }

  /** Returns a copy of {@code data} that keeps only the entries whose key is contained in {@code ids}. */
  private static DataMap filterIds(DataMap data, Set<String> ids) {
    DataMap dm = new DataMap(data.size());
    data.forEach((key, value) -> {
      if (ids.contains(key)) {
        dm.put(key, value);
      }
    });
    return dm;
  }


  //Tuple3: (keys, fields, contains-batch-get); a null fields set means "all fields"
  /** Folds one request into the accumulated reduction state. */
  private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceRequests(final Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
      final Request<?> rq) {
    return reduceContainsBatch(reduceIds(reduceFields(state, rq), rq), rq);
  }

  //Tuple3: (keys, fields, contains-batch-get)
  /** Sets the contains-batch-get flag when {@code request} is a {@link BatchRequest}. */
  private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceContainsBatch(Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
      Request<?> request) {
    if (request instanceof GetRequest) {
      return state;
    } else if (request instanceof BatchRequest) {
      return Tuples.tuple(state._1(), state._2(), true);
    } else {
      throw unsupportedGetRequestType(request);
    }
  }

  //Tuple3: (keys, fields, contains-batch-get)
  /** Adds the id(s) of {@code request} to the accumulated key set (mutates {@code state._1()}). */
  private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceIds(Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
      Request<?> request) {
    if (request instanceof GetRequest) {
      GetRequest<?> getRequest = (GetRequest<?>)request;
      state._1().add(getRequest.getObjectId());
      return state;
    } else if (request instanceof BatchRequest) {
      BatchRequest<?> batchRequest = (BatchRequest<?>)request;
      state._1().addAll(batchRequest.getObjectIds());
      return state;
    } else {
      throw unsupportedGetRequestType(request);
    }
  }

  //Tuple3: (keys, fields, contains-batch-get)
  /**
   * Merges the projection of {@code request} into the accumulated fields. A request without fields asks for all
   * fields, which is recorded as {@code null} and absorbs every other projection from then on.
   */
  private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceFields(Tuple3<Set<Object>, Set<PathSpec>, Boolean> state,
      Request<?> request) {
    if (request instanceof GetRequest || request instanceof BatchRequest) {
      final Set<PathSpec> requestFields = request.getFields();
      if (requestFields != null && !requestFields.isEmpty()) {
        if (state._2() != null) {
          state._2().addAll(requestFields);
        }
        return state;
      } else {
        return Tuples.tuple(state._1(), null, state._3());
      }
    } else {
      throw unsupportedGetRequestType(request);
    }
  }

  /**
   * Sends one BATCH_GET for all {@code ids} and, on success, completes every entry of {@code batch} with the part
   * of the response that belongs to it. On transport failure all entries are failed.
   */
  @SuppressWarnings({ "rawtypes", "unchecked" })
  private <K, RT extends RecordTemplate> void doExecuteBatchGet(final Client client,
    final Batch<RestRequestBatchKey, Response<Object>> batch, final Set<Object> ids, final Set<PathSpec> fields,
    Function<Request<?>, RequestContext> requestContextProvider) {
    final BatchGetEntityRequestBuilder<K, RT> builder = new BatchGetEntityRequestBuilder<>(_baseUriTemplate, _resourceSpec, _requestOptions);
    builder.setHeaders(_headers);
    builder.setCookies(_cookies);
    _queryParams.forEach((key, value) -> builder.setParam(key, value));
    _pathKeys.forEach((key, value) -> builder.pathKey(key, value));

    builder.ids((Set<K>)ids);
    if (fields != null && !fields.isEmpty()) {
      builder.fields(fields.toArray(new PathSpec[fields.size()]));
    }

    final BatchGetEntityRequest<K, RT> batchGet = builder.build();

    client.sendRequest(batchGet, requestContextProvider.apply(batchGet), new Callback<Response<BatchKVResponse<K, EntityResponse<RT>>>>() {

      @Override
      public void onSuccess(Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch) {
        final ProtocolVersion version = ProtocolVersionUtil.extractProtocolVersion(responseToBatch.getHeaders());
        batch.entries().stream()
        .forEach(entry -> {
          try {
            RestRequestBatchKey rrbk = entry.getKey();
            Request request = rrbk.getRequest();
            if (request instanceof GetRequest) {
              successGet((GetRequest) request, responseToBatch, batchGet, entry, version);
            } else if (request instanceof BatchGetKVRequest) {
              successBatchGetKV((BatchGetKVRequest) request, responseToBatch, entry, version);
            } else if (request instanceof BatchGetRequest) {
              successBatchGet((BatchGetRequest) request, responseToBatch, entry, version);
            } else if (request instanceof BatchGetEntityRequest) {
              successBatchGetEntity((BatchGetEntityRequest) request, responseToBatch, entry, version);
            } else {
              entry.getValue().getPromise().fail(unsupportedGetRequestType(request));
            }
          } catch (Exception e) {
            // Fail only this entry: an unexpected error while unbatching one request must not escape the loop and
            // leave the promises of the remaining entries forever pending.
            entry.getValue().getPromise().fail(e);
          }
        });
      }

      @SuppressWarnings({ "deprecation" })
      private void successBatchGetEntity(BatchGetEntityRequest request,
          Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch,
          Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry, final ProtocolVersion version) {
        Set<String> ids = (Set<String>) request.getObjectIds().stream()
            .map(o -> BatchResponse.keyToString(o, version))
            .collect(Collectors.toSet());
        DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
        BatchKVResponse br = new BatchEntityResponse<>(dm, request.getResourceSpec().getKeyType(),
            request.getResourceSpec().getValueType(), request.getResourceSpec().getKeyParts(),
            request.getResourceSpec().getComplexKeyType(), version);
        Response rsp = new ResponseImpl(responseToBatch, br);
        entry.getValue().getPromise().done(rsp);
      }

      private void successBatchGet(BatchGetRequest request, Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch,
          Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry, final ProtocolVersion version) {
        Set<String> ids = (Set<String>) request.getObjectIds().stream()
            .map(o -> BatchResponse.keyToString(o, version))
            .collect(Collectors.toSet());
        DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
        BatchResponse br = new BatchResponse<>(dm, request.getResponseDecoder().getEntityClass());
        Response rsp = new ResponseImpl(responseToBatch, br);
        entry.getValue().getPromise().done(rsp);
      }

      @SuppressWarnings({ "deprecation" })
      private void successBatchGetKV(BatchGetKVRequest request, Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch,
          Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry,
          final ProtocolVersion version) {
        Set<String> ids = (Set<String>) request.getObjectIds().stream()
            .map(o -> BatchResponse.keyToString(o, version))
            .collect(Collectors.toSet());
        DataMap dm = filterIdsInBatchResult(responseToBatch.getEntity().data(), ids);
        BatchKVResponse br = new BatchKVResponse(dm, request.getResourceSpec().getKeyType(),
            request.getResourceSpec().getValueType(), request.getResourceSpec().getKeyParts(),
            request.getResourceSpec().getComplexKeyType(), version);
        Response rsp = new ResponseImpl(responseToBatch, br);
        entry.getValue().getPromise().done(rsp);
      }

      @SuppressWarnings({ "deprecation" })
      private void successGet(GetRequest request,
          Response<BatchKVResponse<K, EntityResponse<RT>>> responseToBatch, final BatchGetEntityRequest<K, RT> batchGet,
          Entry<RestRequestBatchKey, BatchEntry<Response<Object>>> entry, final ProtocolVersion version)
              throws RemoteInvocationException {
        // Round-trip the id through its string form so it matches the key representation used in the batch response.
        String idString = BatchResponse.keyToString(request.getObjectId(), version);
        Object id = ResponseUtils.convertKey(idString, request.getResourceSpec().getKeyType(),
            request.getResourceSpec().getKeyParts(), request.getResourceSpec().getComplexKeyType(), version);
        Response rsp = unbatchResponse(batchGet, responseToBatch, id);
        entry.getValue().getPromise().done(rsp);
      }

      @Override
      public void onError(Throwable e) {
        batch.failAll(e);
      }

    });
  }

  /** Builds the exception used for requests that are neither {@link GetRequest} nor a supported batch GET type. */
  private static RuntimeException unsupportedGetRequestType(Request<?> request) {
    return new RuntimeException("ParSeqRestliClient could not handle this type of GET request: " + request.getClass().getName());
  }

  /**
   * Sends a single GET for the only id in {@code ids} and completes every entry of {@code batch} with its response.
   * On transport failure all entries are failed.
   */
  @SuppressWarnings({ "rawtypes", "unchecked" })
  private <K, RT extends RecordTemplate> void doExecuteGet(final Client client,
      final Batch<RestRequestBatchKey, Response<Object>> batch, final Set<Object> ids, final Set<PathSpec> fields,
      Function<Request<?>, RequestContext> requestContextProvider) {

    final GetRequestBuilder<K, RT> builder = (GetRequestBuilder<K, RT>) new GetRequestBuilder<>(_baseUriTemplate,
        _resourceSpec.getValueClass(), _resourceSpec, _requestOptions);
    builder.setHeaders(_headers);
    builder.setCookies(_cookies);
    _queryParams.forEach((key, value) -> builder.setParam(key, value));
    _pathKeys.forEach((key, value) -> builder.pathKey(key, value));

    builder.id((K) ids.iterator().next());
    if (fields != null && !fields.isEmpty()) {
      builder.fields(fields.toArray(new PathSpec[fields.size()]));
    }

    final GetRequest<RT> get = builder.build();

    client.sendRequest(get, requestContextProvider.apply(get), new Callback<Response<RT>>() {

      @Override
      public void onError(Throwable e) {
        batch.failAll(e);
      }

      @Override
      public void onSuccess(Response<RT> responseToGet) {
        batch.entries().stream().forEach(entry -> {
          Request request = entry.getKey().getRequest();
          if (request instanceof GetRequest) {
            entry.getValue().getPromise().done(new ResponseImpl<>(responseToGet, responseToGet.getEntity()));
          } else {
            entry.getValue().getPromise().fail(unsupportedGetRequestType(request));
          }
        });
      }

    });
  }

  //Tuple3: (keys, fields, contains-batch-get); a null fields set means "all fields"
  /** Reduces all requests in {@code batch} to the union of their ids, their merged projection and a batch flag. */
  private Tuple3<Set<Object>, Set<PathSpec>, Boolean> reduceRequests(
      final Batch<RestRequestBatchKey, Response<Object>> batch) {
    return batch.entries().stream()
      .map(Entry::getKey)
      .map(RestRequestBatchKey::getRequest)
      .reduce(Tuples.tuple(new HashSet<>(), new HashSet<>(), false),
          GetRequestGroup::reduceRequests,
          GetRequestGroup::combine);
  }

  /**
   * Combines two partial reduction states. Ids are unioned. A null fields set ("all fields") on either side makes
   * the combined fields null; otherwise the projections are unioned.
   */
  private static Tuple3<Set<Object>, Set<PathSpec>, Boolean> combine(Tuple3<Set<Object>, Set<PathSpec>, Boolean> a,
      Tuple3<Set<Object>, Set<PathSpec>, Boolean> b) {
    final Set<Object> ids = a._1();
    ids.addAll(b._1());
    final Set<PathSpec> paths;
    if (a._2() == null || b._2() == null) {
      // "All fields" subsumes any explicit projection; calling addAll on/with null would throw NPE.
      paths = null;
    } else {
      paths = a._2();
      paths.addAll(b._2());
    }
    return Tuples.tuple(ids, paths, a._3() || b._3());
  }

  /**
   * Executes all requests in {@code batch}. A plain GET is sent when the batch resolves to exactly one id and
   * contains no batch request; otherwise one BATCH_GET is sent for the union of ids.
   */
  @Override
  public <RT extends RecordTemplate> void executeBatch(final Client client, final Batch<RestRequestBatchKey, Response<Object>> batch,
      Function<Request<?>, RequestContext> requestContextProvider) {
    final Tuple3<Set<Object>, Set<PathSpec>, Boolean> reductionResults = reduceRequests(batch);
    final Set<Object> ids = reductionResults._1();
    final Set<PathSpec> fields = reductionResults._2();
    final boolean containsBatchGet = reductionResults._3();

    LOGGER.debug("executeBatch, ids: '{}', fields: {}", ids, fields);

    if (ids.size() == 1 && !containsBatchGet) {
      doExecuteGet(client, batch, ids, fields, requestContextProvider);
    } else {
      doExecuteBatchGet(client, batch, ids, fields, requestContextProvider);
    }
  }

  @Override
  public String getBaseUriTemplate() {
    return _baseUriTemplate;
  }

  public Map<String, String> getHeaders() {
    return _headers;
  }

  public List<HttpCookie> getCookies() {
    return _cookies;
  }

  /** Returns the grouping query parameters (batch ids and fields parameters excluded). */
  public Map<String, Object> getQueryParams() {
    return _queryParams;
  }

  public Map<String, Object> getPathKeys() {
    return _pathKeys;
  }

  public ResourceSpec getResourceSpec() {
    return _resourceSpec;
  }

  public RestliRequestOptions getRequestOptions() {
    return _requestOptions;
  }


  /**
   * Hash over the same properties {@link #equals(Object)} compares, except the resource key class. Omitting it is
   * still consistent with equals.
   */
  @Override
  public int hashCode() {
    final int prime = 31;
    int result = 1;
    result = prime * result + Objects.hashCode(_baseUriTemplate);
    result = prime * result + Objects.hashCode(_headers);
    result = prime * result + Objects.hashCode(_cookies);
    result = prime * result + Objects.hashCode(_queryParams);
    result = prime * result + Objects.hashCode(_pathKeys);
    result = prime * result + Objects.hashCode(_requestOptions);
    return result;
  }

  /**
   * Two groups are equal when their resource key classes match (or both specs are null) and their base URI template,
   * headers, cookies, grouping query params, path keys and request options are equal.
   */
  @Override
  public boolean equals(Object obj) {
    if (this == obj) {
      return true;
    }
    if (obj == null) {
      return false;
    }
    if (getClass() != obj.getClass()) {
      return false;
    }
    GetRequestGroup other = (GetRequestGroup) obj;

    if (_resourceSpec == null || other._resourceSpec == null) {
      // Equal only if both are null; previously a null on the other side caused an NPE.
      if (_resourceSpec != other._resourceSpec) {
        return false;
      }
    } else if (_resourceSpec.getKeyClass() != other._resourceSpec.getKeyClass()) {
      return false;
    }

    return Objects.equals(_baseUriTemplate, other._baseUriTemplate)
        && Objects.equals(_headers, other._headers)
        && Objects.equals(_cookies, other._cookies)
        && Objects.equals(_queryParams, other._queryParams)
        && Objects.equals(_pathKeys, other._pathKeys)
        && Objects.equals(_requestOptions, other._requestOptions);
  }

  @Override
  public String toString() {
    return "GetRequestGroup [_baseUriTemplate=" + _baseUriTemplate + ", _queryParams=" + _queryParams + ", _pathKeys=" + _pathKeys
        + ", _requestOptions=" + _requestOptions + ", _headers=" + _headers + ", _cookies=" + _cookies
        + ", _maxBatchSize=" + _maxBatchSize + "]";
  }

  /** Names the batch "GET" when it holds a single id, otherwise "BATCH_GET" with request and id counts. */
  @Override
  public <K, V> String getBatchName(final Batch<K, V> batch) {
    return _baseUriTemplate + " " + (batch.batchSize() == 1 ? ResourceMethod.GET : (ResourceMethod.BATCH_GET +
        "(reqs: " + batch.keySize() + ", ids: " + batch.batchSize() + ")"));
  }

  @Override
  public int getMaxBatchSize() {
    return _maxBatchSize;
  }

  /** A key's size is the number of ids it requests. */
  @Override
  public int keySize(RestRequestBatchKey key) {
    return key.ids().size();
  }

}
